import os
import torch
import cv2 as cv
import numpy as np
from torch.utils.data.dataset import Dataset


class ZhnDatasetClass(Dataset):
    """Binary classification dataset made of the frames of labelled videos.

    Every entry of the positive directory is wrapped in a ``VideoInfo`` with
    label 1 and every entry of the negative directory in a ``VideoInfo`` with
    label 0.  The per-video datasets are then exposed as a single flat index
    space: positives first, negatives after.
    """

    def __init__(self, pos_prepath='E:/dataset/instrument/pos',
                 neg_prepath='E:/dataset/instrument/neg'):
        """Open every video found in the two directories.

        Args:
            pos_prepath: Directory whose videos receive label 1.
            neg_prepath: Directory whose videos receive label 0.

        Raises:
            OSError: If one of the directories cannot be listed.
        """
        super().__init__()
        self.video = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible between runs and machines.
        pos_video = sorted(os.listdir(pos_prepath))
        neg_video = sorted(os.listdir(neg_prepath))
        for name in pos_video:
            self.video.append(VideoInfo(pos_prepath + '/' + name, 1))
        for name in neg_video:
            self.video.append(VideoInfo(neg_prepath + '/' + name, 0))

    def __len__(self):
        """Return the total number of samples over all videos."""
        return sum(len(video) for video in self.video)

    def __getitem__(self, item):
        """Return sample ``item`` of the concatenated videos.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``.
                Raising (instead of silently returning ``None``) is what lets
                ``for sample in dataset`` terminate.
        """
        total = len(self)
        if item < 0:
            # Rebind rather than "+=" so a tensor/array index supplied by the
            # caller is never modified in place.
            item = item + total
        if not 0 <= item < total:
            raise IndexError(f'index out of range for dataset of length {total}')
        for video in self.video:
            size = len(video)
            if item < size:
                return video[item]
            item = item - size
        # Only reachable if a video changed length while we were searching.
        raise IndexError(f'index out of range for dataset of length {total}')


class ZhnDatasetDetect(Dataset):
    """Detection dataset concatenating the samples of several ``ImageInfo``.

    Each annotated image is wrapped in an ``ImageInfo``; this class exposes
    all of their samples through one flat index space, image after image.
    """

    # Default annotations: (image_path, x1, x2, y1, y2, kmin), i.e. exactly
    # the positional arguments of ``ImageInfo``.
    IMAGE_SPECS = (
        ('E:/dataset/instrument/test1.png', 195, 373, 83, 355, 0.2),
        ('E:/dataset/instrument/test2.png', 226, 420, 50, 340, 0.2),
        ('E:/dataset/instrument/test3.png', 174, 320, 150, 350, 0.2),
    )

    def __init__(self, image_specs=None):
        """Load every annotated image.

        Args:
            image_specs: Optional iterable of ``ImageInfo`` argument tuples
                ``(image_path, x1, x2, y1, y2, kmin)``.  ``None`` (default)
                uses ``IMAGE_SPECS``.
        """
        super().__init__()
        specs = self.IMAGE_SPECS if image_specs is None else image_specs
        self.img = [ImageInfo(*spec) for spec in specs]

    def __len__(self):
        """Return the total number of samples over all images."""
        return sum(len(image) for image in self.img)

    def __getitem__(self, item):
        """Return sample ``item`` of the concatenated images.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``.
                Raising (instead of silently returning ``None``) is what lets
                ``for sample in dataset`` terminate.
        """
        total = len(self)
        if item < 0:
            # Rebind rather than "+=" so a tensor/array index supplied by the
            # caller is never modified in place.
            item = item + total
        if not 0 <= item < total:
            raise IndexError(f'index out of range for dataset of length {total}')
        for image in self.img:
            size = len(image)
            if item < size:
                return image[item]
            item = item - size
        # Only reachable if an image changed length while we were searching.
        raise IndexError(f'index out of range for dataset of length {total}')


class ImageInfo(Dataset):
    """Synthetic detection samples derived from one annotated image.

    The annotated box spans ``(x1, y1)``-``(x2, y2)``.  Each sample is the
    source image warped so that the box is moved to one position of a grid
    (``STRIDE`` pixels apart) and scaled by one of 10 factors
    ``kmin, kmin + 0.1, ...``; the target is the resulting box centre and size.
    """

    # Extent in pixels within which the box is placed.
    # NOTE(review): presumably the size of the input images - confirm.
    FRAME_WIDTH = 640
    FRAME_HEIGHT = 480
    # Distance in pixels between two neighbouring box positions.
    STRIDE = 4

    def __init__(self, image_path, x1, x2, y1, y2, kmin=0.2):
        """Load the image and derive the sampling grid.

        Args:
            image_path: Path of the image, read with ``cv.imread``.
            x1, x2: Left and right edge of the annotated box.
            y1, y2: Top and bottom edge of the annotated box.
            kmin: Smallest scale factor applied to the box.

        Raises:
            OSError: If the image cannot be read.
        """
        super().__init__()
        self.image = cv.imread(image_path)  # h,w,3
        # cv.imread signals failure by returning None; check explicitly
        # because an assert would be stripped under "python -O".
        if self.image is None:
            raise OSError(f'Image does not exist or cannot be read: {image_path}')
        self.kmin = kmin
        self.x1 = x1
        self.y1 = y1
        self.x = (x1 + x2) // 2
        self.y = (y1 + y2) // 2
        self.w = x2 - x1
        self.h = y2 - y1
        # Number of grid positions per axis; clamped so that a box larger
        # than the frame gives an empty dataset rather than a negative count.
        self.cntx = max((self.FRAME_WIDTH - self.w) // self.STRIDE, 0)
        self.cnty = max((self.FRAME_HEIGHT - self.h) // self.STRIDE, 0)

    def __len__(self):
        """Return the number of samples: grid positions times 10 scales."""
        return self.cntx * self.cnty * 10

    def __getitem__(self, item):
        """Return the ``item``-th warped image and its box target.

        Returns:
            A pair ``(image, box)``: ``image`` is a float32 tensor in c,h,w
            order with values divided by 256; ``box`` is the tensor
            ``[x, y, w, h]`` holding the centre and size of the box after
            the warp.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if item < 0:
            # Rebind rather than "+=" so a tensor/array index is not mutated.
            item = item + total
        if not 0 <= item < total:
            raise IndexError(f'index out of range for dataset of length {total}')
        # Decompose the flat index into grid column, grid row and scale level.
        move_x, move_y, k = matrix_index(item, [self.cntx, self.cnty])
        move_x *= self.STRIDE
        move_y *= self.STRIDE
        k = k / 10 + self.kmin
        # Centre of the box once its top-left corner sits at (move_x, move_y).
        x = np.float32(move_x + self.w//2)
        y = np.float32(move_y + self.h//2)
        # Affine map: translate the box to (move_x, move_y), then scale by k
        # about the new box centre (x, y).
        m = np.float32([[k, 0, k*(move_x-self.x1)+x-k*x],
                        [0, k, k*(move_y-self.y1)+y-k*y]])
        # Pixels not covered by the warped source are filled with white.
        img = cv.warpAffine(self.image, m, (self.image.shape[1], self.image.shape[0]),
                            borderValue=(255, 255, 255))
        w = self.w*k
        h = self.h*k
        img = img.transpose(2, 0, 1)/256  # c,h,w
        return torch.tensor(img, dtype=torch.float32), torch.tensor([x, y, w, h])


class VideoInfo(Dataset):
    """Dataset yielding every frame of one video together with a fixed label."""

    def __init__(self, video_path, label):
        """Open the video and record its frame count.

        Args:
            video_path: Path handed to ``cv.VideoCapture``.
            label: Class label stored as a float32 tensor and returned with
                every frame.

        Raises:
            OSError: If the video cannot be opened.
        """
        super().__init__()
        self.video = cv.VideoCapture(video_path)
        # cv.VideoCapture always returns an object, so a "is not None" check
        # can never fail; isOpened() is what reports an unreadable source.
        if not self.video.isOpened():
            raise OSError(f'Video does not exist or cannot be opened: {video_path}')
        # Clamp to zero so that a backend reporting a negative/unknown frame
        # count cannot make __len__ invalid.
        self.cntframe = max(int(self.video.get(cv.CAP_PROP_FRAME_COUNT)), 0)
        self.label = torch.tensor(label, dtype=torch.float32)

    def __len__(self):
        """Return the number of frames reported by the video container."""
        return self.cntframe

    def __getitem__(self, item):
        """Return frame ``item`` and the video label.

        Returns:
            A pair ``(frame, label)``: ``frame`` is a float32 tensor in
            c,h,w order with values divided by 256; ``label`` is the tensor
            created in ``__init__``.

        Raises:
            IndexError: If ``item`` is outside ``[-len(self), len(self))``.
            RuntimeError: If the frame cannot be decoded.
        """
        if item < 0:
            # Rebind rather than "+=" so a tensor/array index is not mutated.
            item = item + self.cntframe
        if not 0 <= item < self.cntframe:
            raise IndexError(f'index out of range for video of {self.cntframe} frames')
        # Seek to the requested frame before decoding it.
        self.video.set(cv.CAP_PROP_POS_FRAMES, float(item))
        ret, img = self.video.read()
        # Explicit check: an assert would be stripped under "python -O" and
        # the failure would then surface later as an obscure AttributeError.
        if not ret:
            raise RuntimeError(f'Failed to read frame {item} from video')
        img = img.transpose(2, 0, 1)/256
        return torch.tensor(img, dtype=torch.float32), self.label


def matrix_index(cnt, side):
    """Convert a flat index into [x, y, z] coordinates of a stacked grid.

    Unit cubes are stacked into a cuboid whose base is ``side[0]`` by
    ``side[1]``: a row is filled along x, then rows along y, then layers
    along z.  The result is the position of the cube with flat index ``cnt``.

    Example: matrix_index(314, [5, 20]) == [4, 2, 3]
    because 314 = 5*20*3 + 5*2 + 4.
    """
    row_len = side[0]
    layer_size = row_len * side[1]
    # Peel off whole layers first, then whole rows inside the last layer.
    z, within_layer = divmod(cnt, layer_size)
    y, x = divmod(within_layer, row_len)
    return [x, y, z]
